import tensorflow as tf
from .network import Network
from .cfg.config import cfg
from .ctpn_train import ctpn_train
from .ctpn_test import ctpn_detector

def get_network(name):
    """Get a network by name."""
    if name.split('_')[0] == 'ctpn':
        if name.split('_')[1] == 'test':
           return ctpn_detector()
        elif name.split('_')[1] == 'train':
           return ctpn_train()
        else:
           raise KeyError('Unknown dataset: {}'.format(name))
    else:
        raise KeyError('Unknown dataset: {}'.format(name))
        
if __name__ == "__main__":
    # Manual smoke check: construct the inference ("test") network and
    # print whatever object the factory returns.
    print(get_network("ctpn_test"))
